#!/usr/bin/env python

# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import argparse
from pathlib import Path

# Environment variables exported to the training process.  The values are fixed
# by this launcher and overwrite anything already present in the environment.
TRAINING_ENV = {
    'MXNET_UPDATE_ON_KVSTORE': "0",
    'MXNET_EXEC_ENABLE_ADDTO': "1",
    'MXNET_USE_TENSORRT': "0",
    'MXNET_GPU_WORKER_NTHREADS': "2",
    'MXNET_GPU_COPY_NTHREADS': "1",
    'MXNET_OPTIMIZER_AGGREGATION_SIZE': "54",
    'HOROVOD_CYCLE_TIME': "0.1",
    'HOROVOD_FUSION_THRESHOLD': "67108864",
    'HOROVOD_NUM_NCCL_STREAMS': "2",
    'MXNET_HOROVOD_NUM_GROUPS': "16",
    'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD': "999",
    'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD': "25",
}


def positive_int(text):
    """Convert *text* to an int for argparse, rejecting values below 1.

    Raises:
        argparse.ArgumentTypeError: if the parsed integer is not positive.
        ValueError: if *text* is not an integer literal (argparse reports
            this as an invalid argument value).
    """
    value = int(text)
    if value < 1:
        raise argparse.ArgumentTypeError('expected a positive integer, got %r' % text)
    return value


def build_parser():
    """Return the argument parser describing the launcher's own options."""
    parser = argparse.ArgumentParser(description='Train classification models on ImageNet',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # A GPU count or batch size below 1 would produce a meaningless command
    # (`-np 0`, empty `--gpus`, `--batch-size 0`), so reject it up front.
    parser.add_argument('-n', '--ngpus', type=positive_int, default=1, help='number of GPUs to use')
    parser.add_argument('-b', '--batch-size', type=positive_int, default=192, help='batch size per GPU')
    parser.add_argument('-e', '--num-epochs', type=int, default=90, help='number of epochs')
    parser.add_argument('-l', '--lr', type=float, default=0.256, help='learning rate; '
                        'IMPORTANT: true learning rate will be calculated as `lr * batch_size / 256`')
    parser.add_argument('--data-root', type=Path, help='Directory with RecordIO data files',
                        default=Path('/data/imagenet/train-val-recordio-passthrough'))
    parser.add_argument('--dtype', help='Precision', default='float16', choices=('float32', 'float16'))
    parser.add_argument('--kv-store', default='horovod', choices=('device', 'horovod'),
                        help='key-value store type')
    parser.add_argument('--data-backend', default='dali-gpu',
                        choices=('dali-gpu', 'dali-cpu', 'mxnet', 'synthetic'), help='data backend')
    return parser


def parse_args(argv=None):
    """Parse *argv* (``sys.argv[1:]`` when None).

    Returns:
        A ``(namespace, extra)`` pair; ``extra`` holds the arguments this
        launcher does not recognise, which are forwarded verbatim to train.py.
    """
    return build_parser().parse_known_args(argv)


def image_channels(dtype, data_backend):
    """Return the channel count used in ``--image-shape``.

    float16 runs use 4 channels, except with the 'mxnet' data backend;
    every other combination uses 3.
    """
    if dtype == 'float16' and data_backend != 'mxnet':
        return 4
    return 3


def build_command(opts, extra_args=()):
    """Assemble the argv list that launches train.py.

    Args:
        opts: namespace produced by :func:`parse_args`; it is not modified.
        extra_args: additional arguments appended unchanged at the end.

    Returns:
        A list of strings; element 0 is the program to execute.
    """
    # The per-GPU batch size is scaled to a global one, and the learning rate
    # is then scaled by that *global* batch size relative to 256.
    global_batch_size = opts.batch_size * opts.ngpus
    scaled_lr = opts.lr * (global_batch_size / 256)
    data_root = opts.data_root

    command = []
    if 'horovod' in opts.kv_store:
        command += ['horovodrun', '-np', str(opts.ngpus)]
    # train.py is expected to sit next to this launcher.
    command += ['python', str(Path(__file__).parent / "train.py")]
    command += [
        '--data-train', str(data_root / "train.rec"),
        '--data-train-idx', str(data_root / "train.idx"),
        '--data-val', str(data_root / "val.rec"),
        '--data-val-idx', str(data_root / "val.idx"),
        '--dtype', opts.dtype,
        '--image-shape', '%d,224,224' % image_channels(opts.dtype, opts.data_backend),
    ]
    if opts.dtype == 'float16':
        command += [
            '--fuse-bn-relu', '1',
            '--fuse-bn-add-relu', '1',
            '--input-layout', 'NCHW',
            '--conv-layout', 'NHWC',
            '--batchnorm-layout', 'NHWC',
            '--pooling-layout', 'NHWC',
        ]
    command += [
        '--kv-store', opts.kv_store,
        '--data-backend', opts.data_backend,
        '--lr', str(scaled_lr),
        '--gpus', ','.join(map(str, range(opts.ngpus))),
        '--batch-size', str(global_batch_size),
        '--num-epochs', str(opts.num_epochs),
    ]
    command += list(extra_args)
    return command


def main(argv=None):
    """Parse the command line, export the training environment and exec the trainer.

    ``os.execvp`` replaces the current process, so this function does not
    return on success.
    """
    opts, extra_args = parse_args(argv)
    command = build_command(opts, extra_args)
    os.environ.update(TRAINING_ENV)
    os.execvp(command[0], command)


if __name__ == '__main__':
    main()
